{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 10.3 word2vec的实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0.0\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import math\n",
    "import random\n",
    "import sys\n",
    "import time\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.utils.data as Data\n",
    "\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.3.1 处理数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "assert 'ptb.train.txt' in os.listdir(\"../../data/ptb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# sentences: 42068'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open('../../data/ptb/ptb.train.txt', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "    # st是sentence的缩写\n",
    "    raw_dataset = [st.split() for st in lines]\n",
    "\n",
    "'# sentences: %d' % len(raw_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']\n",
      "# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']\n",
      "# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']\n"
     ]
    }
   ],
   "source": [
    "for st in raw_dataset[:3]:\n",
    "    print('# tokens:', len(st), st[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3.1.1 建立词语索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# tk是token的缩写\n",
    "counter = collections.Counter([tk for st in raw_dataset for tk in st])\n",
    "counter = dict(filter(lambda x: x[1] >= 5, counter.items()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# tokens: 887100'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx_to_token = [tk for tk, _ in counter.items()]\n",
    "token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}\n",
    "dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]\n",
    "           for st in raw_dataset]\n",
    "num_tokens = sum([len(st) for st in dataset])\n",
    "'# tokens: %d' % num_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3.1.2 二次采样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# tokens: 375647'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def discard(idx):\n",
    "    return random.uniform(0, 1) < 1 - math.sqrt(\n",
    "        1e-4 / counter[idx_to_token[idx]] * num_tokens)\n",
    "\n",
    "subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]\n",
    "'# tokens: %d' % sum([len(st) for st in subsampled_dataset])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# the: before=50770, after=2043'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def compare_counts(token):\n",
    "    return '# %s: before=%d, after=%d' % (token, sum(\n",
    "        [st.count(token_to_idx[token]) for st in dataset]), sum(\n",
    "        [st.count(token_to_idx[token]) for st in subsampled_dataset]))\n",
    "\n",
    "compare_counts('the')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# join: before=45, after=45'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compare_counts('join')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3.1.3 提取中心词和背景词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_centers_and_contexts(dataset, max_window_size):\n",
    "    centers, contexts = [], []\n",
    "    for st in dataset:\n",
    "        if len(st) < 2:  # 每个句子至少要有2个词才可能组成一对“中心词-背景词”\n",
    "            continue\n",
    "        centers += st\n",
    "        for center_i in range(len(st)):\n",
    "            window_size = random.randint(1, max_window_size)\n",
    "            indices = list(range(max(0, center_i - window_size),\n",
    "                                 min(len(st), center_i + 1 + window_size)))\n",
    "            indices.remove(center_i)  # 将中心词排除在背景词之外\n",
    "            contexts.append([st[idx] for idx in indices])\n",
    "    return centers, contexts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n",
      "center 0 has contexts [1, 2]\n",
      "center 1 has contexts [0, 2]\n",
      "center 2 has contexts [0, 1, 3, 4]\n",
      "center 3 has contexts [1, 2, 4, 5]\n",
      "center 4 has contexts [3, 5]\n",
      "center 5 has contexts [4, 6]\n",
      "center 6 has contexts [4, 5]\n",
      "center 7 has contexts [8, 9]\n",
      "center 8 has contexts [7, 9]\n",
      "center 9 has contexts [7, 8]\n"
     ]
    }
   ],
   "source": [
    "tiny_dataset = [list(range(7)), list(range(7, 10))]\n",
    "print('dataset', tiny_dataset)\n",
    "for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):\n",
    "    print('center', center, 'has contexts', context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.3.2 负采样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_negatives(all_contexts, sampling_weights, K):\n",
    "    all_negatives, neg_candidates, i = [], [], 0\n",
    "    population = list(range(len(sampling_weights)))\n",
    "    for contexts in all_contexts:\n",
    "        negatives = []\n",
    "        while len(negatives) < len(contexts) * K:\n",
    "            if i == len(neg_candidates):\n",
    "                # 根据每个词的权重（sampling_weights）随机生成k个词的索引作为噪声词。\n",
    "                # 为了高效计算，可以将k设得稍大一点\n",
    "                i, neg_candidates = 0, random.choices(\n",
    "                    population, sampling_weights, k=int(1e5))\n",
    "            neg, i = neg_candidates[i], i + 1\n",
    "            # 噪声词不能是背景词\n",
    "            if neg not in set(contexts):\n",
    "                negatives.append(neg)\n",
    "        all_negatives.append(negatives)\n",
    "    return all_negatives\n",
    "\n",
    "sampling_weights = [counter[w]**0.75 for w in idx_to_token]\n",
    "all_negatives = get_negatives(all_contexts, sampling_weights, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.3.3 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def batchify(data):\n",
    "    \"\"\"用作DataLoader的参数collate_fn: 输入是个长为batchsize的list, list中的每个元素都是__getitem__得到的结果\"\"\"\n",
    "    max_len = max(len(c) + len(n) for _, c, n in data)\n",
    "    centers, contexts_negatives, masks, labels = [], [], [], []\n",
    "    for center, context, negative in data:\n",
    "        cur_len = len(context) + len(negative)\n",
    "        centers += [center]\n",
    "        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]\n",
    "        masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n",
    "        labels += [[1] * len(context) + [0] * (max_len - len(context))]\n",
    "    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),\n",
    "            torch.tensor(masks), torch.tensor(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "centers shape: torch.Size([512, 1])\n",
      "contexts_negatives shape: torch.Size([512, 60])\n",
      "masks shape: torch.Size([512, 60])\n",
      "labels shape: torch.Size([512, 60])\n"
     ]
    }
   ],
   "source": [
    "class MyDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, centers, contexts, negatives):\n",
    "        assert len(centers) == len(contexts) == len(negatives)\n",
    "        self.centers = centers\n",
    "        self.contexts = contexts\n",
    "        self.negatives = negatives\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        return (self.centers[index], self.contexts[index], self.negatives[index])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.centers)\n",
    "\n",
    "batch_size = 512\n",
    "num_workers = 0 if sys.platform.startswith('win32') else 4\n",
    "\n",
    "dataset = MyDataset(all_centers, \n",
    "                    all_contexts, \n",
    "                    all_negatives)\n",
    "data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,\n",
    "                            collate_fn=batchify, \n",
    "                            num_workers=num_workers)\n",
    "for batch in data_iter:\n",
    "    for name, data in zip(['centers', 'contexts_negatives', 'masks',\n",
    "                           'labels'], batch):\n",
    "        print(name, 'shape:', data.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.3.4 跳字模型\n",
    "### 10.3.4.1 嵌入层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-2.8935,  1.9747, -0.2081, -0.6574],\n",
       "        [ 1.3135, -1.7396, -1.4210,  1.3302],\n",
       "        [-0.0465,  1.0802, -0.5344,  0.5250],\n",
       "        [-0.6899,  1.1832, -0.1694,  0.1382],\n",
       "        [-1.3940, -1.4121,  0.1867,  0.7681],\n",
       "        [ 0.2224, -0.3751,  0.5170,  0.1359],\n",
       "        [-1.4377,  0.4700,  0.5167,  0.8427],\n",
       "        [ 1.5523,  0.0542,  1.2034, -0.1215],\n",
       "        [-0.4874, -0.7876, -1.1580,  0.0728],\n",
       "        [-1.4077, -0.8691, -0.8106, -0.0612],\n",
       "        [-0.4633, -1.8948,  0.1791,  2.1354],\n",
       "        [ 0.4180,  1.3088,  1.2537,  2.0183],\n",
       "        [ 1.5453,  1.3754, -0.3551,  0.4333],\n",
       "        [ 1.7966, -0.2033, -0.5374, -0.0457],\n",
       "        [ 1.7540,  0.3209,  0.9063,  1.0655],\n",
       "        [-0.2148, -0.0743, -1.9261,  1.1415],\n",
       "        [-0.6571, -0.7888,  0.6224,  1.0660],\n",
       "        [-1.5191,  1.7596,  0.8295,  0.8935],\n",
       "        [ 0.4348, -0.2445, -0.6763,  1.5176],\n",
       "        [ 0.2910,  0.4196, -1.6204,  1.8422]], requires_grad=True)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embed = nn.Embedding(num_embeddings=20, embedding_dim=4)\n",
    "embed.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.3135, -1.7396, -1.4210,  1.3302],\n",
       "         [-0.0465,  1.0802, -0.5344,  0.5250],\n",
       "         [-0.6899,  1.1832, -0.1694,  0.1382]],\n",
       "\n",
       "        [[-1.3940, -1.4121,  0.1867,  0.7681],\n",
       "         [ 0.2224, -0.3751,  0.5170,  0.1359],\n",
       "         [-1.4377,  0.4700,  0.5167,  0.8427]]], grad_fn=<EmbeddingBackward>)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)\n",
    "embed(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3.4.2 小批量乘法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 6])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.ones((2, 1, 4))\n",
    "Y = torch.ones((2, 4, 6))\n",
    "torch.bmm(X, Y).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3.4.3 跳字模型前向计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def skip_gram(center, contexts_and_negatives, embed_v, embed_u):\n",
    "    v = embed_v(center)\n",
    "    u = embed_u(contexts_and_negatives)\n",
    "    pred = torch.bmm(v, u.permute(0, 2, 1))\n",
    "    return pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.3.5 训练模型\n",
    "### 10.3.5.1 二元交叉熵损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SigmoidBinaryCrossEntropyLoss(nn.Module):\n",
    "    def __init__(self): # none mean sum\n",
    "        super(SigmoidBinaryCrossEntropyLoss, self).__init__()\n",
    "    def forward(self, inputs, targets, mask=None):\n",
    "        \"\"\"\n",
    "        input – Tensor shape: (batch_size, len)\n",
    "        target – Tensor of the same shape as input\n",
    "        \"\"\"\n",
    "        inputs, targets, mask = inputs.float(), targets.float(), mask.float()\n",
    "        res = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\", weight=mask)\n",
    "        return res.mean(dim=1)\n",
    "\n",
    "loss = SigmoidBinaryCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8740, 1.2100])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])\n",
    "# 标签变量label中的1和0分别代表背景词和噪声词\n",
    "label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])\n",
    "mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # 掩码变量\n",
    "loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8740\n",
      "1.2100\n"
     ]
    }
   ],
   "source": [
    "def sigmd(x):\n",
    "    return - math.log(1 / (1 + math.exp(-x)))\n",
    "\n",
    "print('%.4f' % ((sigmd(1.5) + sigmd(-0.3) + sigmd(1) + sigmd(-2)) / 4)) # 注意1-sigmoid(x) = sigmoid(-x)\n",
    "print('%.4f' % ((sigmd(1.1) + sigmd(-0.6) + sigmd(-2.2)) / 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3.5.2 初始化模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "embed_size = 100\n",
    "net = nn.Sequential(\n",
    "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),\n",
    "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.3.5.3 定义训练函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(net, lr, num_epochs):\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    print(\"train on\", device)\n",
    "    net = net.to(device)\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "    for epoch in range(num_epochs):\n",
    "        start, l_sum, n = time.time(), 0.0, 0\n",
    "        for batch in data_iter:\n",
    "            center, context_negative, mask, label = [d.to(device) for d in batch]\n",
    "            \n",
    "            pred = skip_gram(center, context_negative, net[0], net[1])\n",
    "            \n",
    "            # 使用掩码变量mask来避免填充项对损失函数计算的影响\n",
    "            l = (loss(pred.view(label.shape), label, mask) *\n",
    "                 mask.shape[1] / mask.float().sum(dim=1)).mean() # 一个batch的平均loss\n",
    "            optimizer.zero_grad()\n",
    "            l.backward()\n",
    "            optimizer.step()\n",
    "            l_sum += l.cpu().item()\n",
    "            n += 1\n",
    "        print('epoch %d, loss %.2f, time %.2fs'\n",
    "              % (epoch + 1, l_sum / n, time.time() - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train on cpu\n",
      "epoch 1, loss 1.97, time 74.53s\n",
      "epoch 2, loss 0.62, time 81.85s\n",
      "epoch 3, loss 0.45, time 74.49s\n",
      "epoch 4, loss 0.39, time 72.04s\n",
      "epoch 5, loss 0.37, time 72.21s\n",
      "epoch 6, loss 0.35, time 71.81s\n",
      "epoch 7, loss 0.34, time 72.00s\n",
      "epoch 8, loss 0.33, time 74.45s\n",
      "epoch 9, loss 0.32, time 72.08s\n",
      "epoch 10, loss 0.32, time 72.05s\n"
     ]
    }
   ],
   "source": [
    "train(net, 0.01, 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.3.6 应用词嵌入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.478: hard-disk\n",
      "cosine sim=0.446: intel\n",
      "cosine sim=0.440: drives\n"
     ]
    }
   ],
   "source": [
    "def get_similar_tokens(query_token, k, embed):\n",
    "    W = embed.weight.data\n",
    "    x = W[token_to_idx[query_token]]\n",
    "    # 添加的1e-9是为了数值稳定性\n",
    "    cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()\n",
    "    _, topk = torch.topk(cos, k=k+1)\n",
    "    topk = topk.cpu().numpy()\n",
    "    for i in topk[1:]:  # 除去输入词\n",
    "        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))\n",
    "        \n",
    "get_similar_tokens('chip', 3, net[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
